import os
os.chdir("../")
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import sys
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from models.r2d2 import OBLR2D2Agent
from utils.memory import Memory, LocalBuffer, OBLMemory, OBLLocalBuffer, MIOBLMemory, MIOBLLocalBuffer
from tensorboardX import SummaryWriter

from models.r2d2_config import initial_exploration, batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, sequence_length, local_mini_batch, use_mi_loss
from utils.pbmaze_config import multi_env_config, multi_iql_env_config
from multi_phone_booth_collab_maze import PBCMaze
from pbcmaze_belief_model import ReceiverBeliefModel, MultiPBSenderBeliefModel
from collections import deque
# Output directories are keyed on the third character of str() of the last
# element of the last entry in multi_env_config["booth_locs"].
# NOTE(review): presumably that element is a small decimal such as 0.1 (-> "1");
# confirm — a value whose string form is shorter than 3 characters would raise
# IndexError here at import time.
noise_str = str(multi_env_config["booth_locs"][-1][-1])
RESULT_PATH = f"results/multi_booths/{noise_str[2]}/"
MODEL_PATH = f"trained_models/multi_booths/{noise_str[2]}/"
NUM_RUNS = 4  # number of independent training runs performed by main()

def set_seed(seed):
    """Seed PyTorch's, Python's and NumPy's global RNGs with the same value.

    Args:
        seed: integer seed passed unchanged to each library's seeding function.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)


def evaluate(eval_env, a0_agent, a1_agent):
    """Run one greedy (argmax) evaluation episode with both agents.

    Agents alternate turns until ``eval_env.step`` reports ``done``: agent 0
    acts first (its policy is also scored with ``eval_env.calculate_mi_reward``),
    then agent 1 acts. Each agent's recurrent hidden state starts as a pair of
    zero tensors and is carried across its own turns.

    Args:
        eval_env: environment exposing ``reset``, ``get_obs``, ``step`` and
            ``calculate_mi_reward``.
        a0_agent: agent 0; its action policy is also passed to ``step``.
        a1_agent: agent 1.

    Returns:
        ``(score, info, mi_score)``: the summed rewards of both agents over the
        episode, the ``info`` returned by the last ``step`` call (agent 1's
        turn), and the sum of the MI values returned by
        ``calculate_mi_reward`` across agent 0's turns.
    """
    def zero_hidden():
        # Same (h, c)-style pair of zero tensors used throughout this file.
        return (torch.Tensor().new_zeros(1, 1, 16).to(device),
                torch.Tensor().new_zeros(1, 1, 16).to(device))

    score = 0
    mi_score = 0
    done = False
    info = None
    eval_env.reset()

    a0_hidden = zero_hidden()
    a1_hidden = zero_hidden()

    with torch.no_grad():
        while not done:
            # Agent 0's turn
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden = a0_agent.get_action(a0_obs, a0_hidden, argmax=True)
            # .cpu() is required before .numpy() when the agent runs on a GPU;
            # it is a no-op for CPU tensors. Convert once and reuse.
            a0_policy_np = a0_policy.squeeze().detach().cpu().numpy()
            _, mi = eval_env.calculate_mi_reward(a0_policy_np, 0, None)
            mi_score += mi
            a0_reward, done, info = eval_env.step(0, a0_action, a0_policy_np)

            # Agent 1's turn (taken even if agent 0's step ended the episode,
            # mirroring the training loop in main()).
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden = a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            a1_reward, done, info = eval_env.step(1, a1_action)

            score += a0_reward + a1_reward
    return score, info, mi_score

def evaluate_policy(a0_agent, a1_agent):
    """Debug helper: print agent 0's greedy policy once both agents reach their booths.

    A fresh ``PBCMaze`` built from ``multi_iql_env_config`` is stepped with
    fixed actions (agent 0 takes action 1, agent 1 takes action 0) until
    ``env.agent0_loc[0]`` equals ``env.booth_loc`` and ``env.agent1_loc[0]``
    equals ``env.receiver_booth_loc``. Along the way both agents are queried on
    every observation so their recurrent hidden states roll forward, but the
    actions they choose are discarded. Finally agent 0's greedy policy for the
    last observation is printed.

    NOTE(review): if the fixed actions never bring both agents to their booths,
    the while loop never terminates.
    """
    def fresh_hidden():
        return tuple(torch.Tensor().new_zeros(1, 1, 16).to(device) for _ in range(2))

    with torch.no_grad():
        maze = PBCMaze(env_args=multi_iql_env_config)
        maze.reset()
        sender_state = fresh_hidden()
        receiver_state = fresh_hidden()

        while (maze.agent0_loc[0] != maze.booth_loc
               or maze.agent1_loc[0] != maze.receiver_booth_loc):
            sender_view = torch.Tensor(maze.get_obs(0)).to(device)
            _, _, sender_state = a0_agent.get_action(sender_view, sender_state, argmax=True)
            receiver_view = torch.Tensor(maze.get_obs(1)).to(device)
            _, _, receiver_state = a1_agent.get_action(receiver_view, receiver_state, argmax=True)
            maze.step(0, 1)
            maze.step(1, 0)

        final_view = torch.Tensor(maze.get_obs(0)).to(device)
        greedy_policy, _, _ = a0_agent.get_action(final_view, sender_state, argmax=True)
        print(greedy_policy)

def _policy_to_numpy(policy):
    """Return a policy tensor as a NumPy array (squeezed, detached, moved to CPU).

    ``.cpu()`` is needed when the agent runs on a GPU and is a no-op otherwise.
    """
    return policy.squeeze().detach().cpu().numpy()


def _result_suffix(use_iql_for_stage_2):
    """Build the filename suffix shared by all result files and model paths.

    The suffix encodes which training options were active:
    MI shaping / MI loss / IQL-for-stage-2 / argmax evaluation, or "_ir" when
    only intermediate rewards are used.
    """
    argmax = "_argmax" if eval_argmax else ""
    if multi_env_config['use_mi_shaping'] or use_mi_loss:
        return (("_mi_log2" if multi_env_config['use_mi_shaping'] else "")
                + ("_mi_loss" if use_mi_loss else "")
                + ("_util_iql" if use_iql_for_stage_2 else "")
                + argmax)
    if multi_env_config['use_intermediate_reward']:
        return "_ir" + argmax
    return ("_util_iql" if use_iql_for_stage_2 else "") + argmax


def main():
    """Train OBL R2D2 sender/receiver agents for NUM_RUNS seeds and save results.

    For each run (seeded with the run index) a fresh training env, evaluation
    env and pair of agents is built. The agents are trained for
    ``num_episodes`` episodes with alternating turns (agent 0 = sender,
    agent 1 = receiver). Either per-episode training metrics or, when
    ``eval_argmax`` is set, periodic greedy-evaluation metrics are collected.
    After all runs the metrics are written as ``.npy`` arrays under
    RESULT_PATH and the agents from the final run are saved under MODEL_PATH.
    """
    # Training schedule (identical for every run).
    num_episodes = 12000
    num_episodes_for_mi_training = 12000
    stage_2_training = False
    use_iql_for_stage_2 = False
    # IQL updates happen only in stage-2 IQL mode; OBL sampling/training otherwise.
    use_iql = stage_2_training and use_iql_for_stage_2
    use_obl = not use_iql

    sender_time_to_booth_result = []
    receiver_time_to_booth_result = []
    booth_visits_result = []
    reward_result = []
    eval_reward_result = []
    eval_mi_reward_result = []
    running_reward_result = []
    running_eval_reward_result = []
    for run_idx in range(NUM_RUNS):
        print("Run: " + str(run_idx + 1))
        set_seed(run_idx)

        # Env: the evaluation env copies the training env's config but has
        # MI shaping and intermediate rewards disabled.
        env = PBCMaze(env_args=multi_env_config)
        env.reset()
        eval_env = PBCMaze(env_args=multi_env_config)
        eval_env.reset()
        eval_env.load_env_config(env.save_env_config())
        eval_env.use_mi_shaping = False
        eval_env.use_intermediate_reward = False

        """
        Agent 0 obs: ((channel, width, height), goal feature)
        Agent 1 obs: ((channel, width, height), communication token)
        """
        a0_input_shape = env.get_obs_size(0)
        a1_input_shape = env.get_obs_size(1)
        a0_num_actions = 7
        a1_num_actions = 5

        # Uniform initial policies for the belief models.
        receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
        sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
        rb_model = ReceiverBeliefModel(receiver_pi_0, env)
        sb_model = MultiPBSenderBeliefModel(sender_pi_0, env)
        if use_mi_loss:
            a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), lr, batch_size, device, 0, rb_model, use_mi_loss, multi_pb = True)
        else:
            a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 0, rb_model, use_mi_loss, multi_pb = True)
        a1_agent = OBLR2D2Agent(a1_input_shape, a1_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 1, sb_model, multi_pb = True)

        writer = SummaryWriter('logs')

        running_score = 0
        running_eval_score = 0
        steps = 0
        loss = 0
        per_run_sender_time_to_booth_list = []
        per_run_receiver_time_to_booth_list = []
        per_run_booth_visits_list = []
        per_run_reward = []
        per_run_eval_reward = []
        per_run_mi_reward = []
        per_run_running_reward = []
        per_run_running_eval_reward = []
        for e in range(num_episodes):
            done = False
            score = 0
            a1_reward = None
            env.reset()

            a0_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
            a1_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
            a1_next_hidden = None

            while not done:
                steps += 1

                # Agent 0's turn
                a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                a0_policy, a0_action, a0_next_hidden = a0_agent.get_action(a0_obs, a0_hidden)
                a0_policy_np = _policy_to_numpy(a0_policy)
                # OBL sampling
                a0_curr_env_config = env.save_env_config()
                if use_obl:
                    a0_agent.obl_sampling(a0_hidden, a0_next_hidden, a0_policy_np, a0_action, a0_curr_env_config, a1_agent, a1_next_hidden if a1_next_hidden is not None else a1_hidden)
                a0_reward, done, info = env.step(0, a0_action, a0_policy_np)

                if use_iql:
                    # Agent 1's previous transition is completed now that
                    # agent 0 has acted (its reward includes agent 0's reward).
                    if a1_reward is not None:
                        mask = 0 if done else 1
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward + a0_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 1's IQL learning
                    if steps > initial_exploration and len(a1_agent.iql_memory) > batch_size:
                        loss, _ = a1_agent.train_iql_model()
                        if steps % update_target == 0:
                            a1_agent.update_target_model()

                # Update after a0 has taken an action
                if a1_next_hidden is not None:
                    a1_hidden = a1_next_hidden

                if len(a0_agent.local_buffer.memory) == local_mini_batch:
                    a0_agent.push_to_memory()

                if steps > initial_exploration and len(a0_agent.memory) > batch_size:
                    if use_obl:
                        loss, _ = a0_agent.train_model(obl = True, use_mi_loss = use_mi_loss)
                    if steps % update_target == 0:
                        a0_agent.update_target_model()
                # Update agent 1's belief with agent 0's communication token
                a1_agent.belief_model.update_belief(comm_token = env.comm_token)

                # Agent 1's turn
                a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                a1_policy, a1_action, a1_next_hidden = a1_agent.get_action(a1_obs, a1_hidden)
                a1_curr_env_config = env.save_env_config()
                if use_obl:
                    a1_agent.obl_sampling(a1_hidden, a1_next_hidden, _policy_to_numpy(a1_policy), a1_action, a1_curr_env_config, a0_agent, a0_next_hidden)
                a1_reward, done, info = env.step(1, a1_action)

                if use_iql:
                    # Complete agent 0's transition now that agent 1 has acted.
                    mask = 0 if done else 1
                    next_a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                    a0_agent.iql_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, a0_hidden)
                    if len(a0_agent.iql_buffer.memory) == local_mini_batch:
                        a0_agent.push_to_iql_memory()

                    if done:
                        # Episode over: agent 1's last transition will not be
                        # completed on agent 0's next turn, so push it now.
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 0's IQL learning
                    if steps > initial_exploration and len(a0_agent.iql_memory) > batch_size:
                        loss, _ = a0_agent.train_iql_model()
                        if steps % update_target == 0:
                            a0_agent.update_target_model()

                a0_hidden = a0_next_hidden

                if len(a1_agent.local_buffer.memory) == local_mini_batch:
                    a1_agent.push_to_memory()

                if steps > initial_exploration and len(a1_agent.memory) > batch_size:
                    if use_obl:
                        loss, _ = a1_agent.train_model(obl = True)
                    if steps % update_target == 0:
                        a1_agent.update_target_model()
                # Update agent 0's belief after agent 1's action
                a0_agent.belief_model.update_belief()

                score += a0_reward + a1_reward

            running_score = 0.99 * running_score + 0.01 * score
            # Steps to phone booth
            if eval_argmax:
                if e % eval_interval == 0:
                    eval_score, info, eval_mi_score = evaluate(eval_env, a0_agent, a1_agent)
                    booth_visits = info["sender_booth_visits"]
                    running_eval_score = 0.99 * running_eval_score + 0.01 * eval_score
                    sender_time_to_pb = info["sender_time_to_booth"]
                    receiver_time_to_pb = info["receiver_time_to_booth"]
                    per_run_booth_visits_list.append(booth_visits)
                    per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                    per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                    per_run_eval_reward.append(eval_score)
                    per_run_mi_reward.append(eval_mi_score)
                    per_run_running_eval_reward.append(running_eval_score)
                    print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | SenderBoothVisits: {} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_eval_score, eval_score, sender_time_to_pb, booth_visits, receiver_time_to_pb))
                    sys.stdout.flush()
            else:
                sender_time_to_pb = info["sender_time_to_booth"]
                receiver_time_to_pb = info["receiver_time_to_booth"]
                per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                per_run_reward.append(score)
                per_run_running_reward.append(running_score)
                if e % log_interval == 0:
                    print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_score, score, sender_time_to_pb, receiver_time_to_pb))
                    writer.add_scalar('log/score', float(running_score), e)
                    writer.add_scalar('log/loss', float(loss), e)
                    sys.stdout.flush()

            # Reset belief
            a0_agent.belief_model.reset_belief()
            a1_agent.belief_model.reset_belief()

            # Turn off MI training once the MI-training budget is used up
            if (e + 1) >= num_episodes_for_mi_training:
                env.turn_off_mi_training()

        # Flush and release the per-run TensorBoard writer.
        writer.close()

        sender_time_to_booth_result.append(per_run_sender_time_to_booth_list)
        receiver_time_to_booth_result.append(per_run_receiver_time_to_booth_list)
        booth_visits_result.append(per_run_booth_visits_list)
        if eval_argmax:
            eval_reward_result.append(per_run_eval_reward)
            # Previously appended per_run_eval_reward here by mistake, which made
            # the saved MI-reward file a copy of the eval-reward file.
            eval_mi_reward_result.append(per_run_mi_reward)
            running_eval_reward_result.append(per_run_running_eval_reward)
        else:
            reward_result.append(per_run_reward)
            running_reward_result.append(per_run_running_reward)

    # Save results
    os.makedirs(RESULT_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)

    uses_mi = multi_env_config['use_mi_shaping'] or use_mi_loss
    suffix = _result_suffix(use_iql_for_stage_2)

    def result_file(stem):
        return RESULT_PATH + stem + suffix + ".npy"

    if eval_argmax:
        np.save(result_file("obl_reward"), np.array(eval_reward_result))
        if uses_mi:
            np.save(result_file("obl_mi_reward"), np.array(eval_mi_reward_result))
        np.save(result_file("obl_running_reward"), np.array(running_eval_reward_result))
    else:
        np.save(result_file("obl_reward"), np.array(reward_result))
        np.save(result_file("obl_running_reward"), np.array(running_reward_result))

    np.save(result_file("obl_sender_time_to_pb"), np.array(sender_time_to_booth_result))
    np.save(result_file("obl_receiver_time_to_pb"), np.array(receiver_time_to_booth_result))
    np.save(result_file("obl_booth_visits"), np.array(booth_visits_result))

    # Save the final run's models.
    # NOTE(review): the space after "model" in these paths looks accidental but
    # is kept so existing checkpoints / loaders keep matching.
    a0_agent.save_model(MODEL_PATH + "obl_sender_model " + suffix)
    a1_agent.save_model(MODEL_PATH + "obl_receiver_model " + suffix)

if __name__=="__main__":
    main()
